#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division, with_statement
'''
Copyright 2015, 陈同 (chentong_biology@163.com).  
===========================================================
'''
__author__ = 'chentong & ct586[9]'
__author_email__ = 'chentong_biology@163.com'
#=========================================================
desc = '''
Program description:
    This is designed to summarize reads distribution output by `STAR`.
'''

import sys
import os
from json import dumps as json_dumps
from time import localtime, strftime 
timeformat = "%Y-%m-%d %H:%M:%S"
from optparse import OptionParser as OP
import re
from tools import *
#from multiprocessing.dummy import Pool as ThreadPool

#from bs4 import BeautifulSoup
# Python-2-only incantation: re-expose `sys.setdefaultencoding` (hidden by
# site.py at startup) so implicit str<->unicode conversion defaults to UTF-8;
# this script embeds Chinese text. Neither call exists under Python 3.
reload(sys)
sys.setdefaultencoding('utf8')

# Global debug flag; overwritten in main() from the -D/--debug option.
debug = 0

def fprint(content):
    """Pretty-print `content` to stdout as indented JSON.

    Args:
        content: any JSON-serializable object (dict, list, str, number...).

    Returns:
        None

    Raises:
        TypeError: if `content` is not JSON-serializable.
    """
    # Parenthesized single-argument print works under both Python 2 and 3.
    print(json_dumps(content, indent=1))

def cmdparameter(argv):
    """Parse command line options for the STAR mapping-summary script.

    Args:
        argv(list): full argument vector, normally sys.argv.

    Returns:
        tuple: (options, args) as produced by optparse.

    Raises:
        SystemExit: when called with no arguments (usage text printed first).
        AssertionError: when the mandatory -f/--files option is missing.
    """
    if len(argv) == 1:
        global desc
        # Show the program description, then delegate to optparse's own help.
        sys.stderr.write(desc + '\n')
        cmd = 'python ' + argv[0] + ' -h'
        os.system(cmd)
        sys.exit(1)
    usages = "%prog -f file"
    parser = OP(usage=usages)
    parser.add_option("-f", "--files", dest="filein",
        metavar="FILEIN", help="`,` or ` ` separated a list of files. *.Log.final.out generated by `STAR` during mapping")
    parser.add_option("-l", "--labels", dest="label",
        metavar="LABEL", help="`,` or ` ` separated a list of labels to label each file. It must have same order as files.")
    parser.add_option("-o", "--output-prefix", dest="out_prefix",
        help="The prefix of output files.")
    parser.add_option("-r", "--report-dir", dest="report_dir",
        default='report', help="Directory for report files. Default 'report'.")
    parser.add_option("-b", "--bigwig", dest="bigwig",
        default=False, action="store_true", help="Rsyncing bigwig files. Default FALSE.")
    parser.add_option("-R", "--report-sub-dir", dest="report_sub_dir",
        default='2_mapping_quality', help="Directory for saving report figures and tables. This dir will put under <report_dir>,  so only dir name is needed. Default '2_mapping_quality'.")
    parser.add_option("-d", "--doc-only", dest="doc_only",
        default=False, action="store_true", help="Specify to only generate doc.")
    parser.add_option("-n", "--number", dest="number", type="int", 
        default=40, help="Set the maximum allowed samples for barplot. Default 40.\
 If more than this number of samples are given, heatmap will be used instead.")
    parser.add_option("-v", "--verbose", dest="verbose",
        action="store_true", help="Show process information")
    parser.add_option("-D", "--debug", dest="debug",
        default=False, action="store_true", help="Debug the program")
    (options, args) = parser.parse_args(argv[1:])
    # BUG FIX: the original message referenced a non-existent "-i" option;
    # the mandatory option is -f/--files.
    assert options.filein is not None, "A filename needed for -f"
    return (options, args)
#--------------------------------------------------------------------

def readStarFinalOut(file):
    """Parse a STAR `*.Log.final.out` file into a dict of statistics.

    Everything before the first blank line (the run header) is skipped.
    Each later "name | value" line is split at the first '|':
      * values with neither '%' nor '.' are converted to int (raw counts)
      * values containing '%' have the '%' stripped but remain strings
      * remaining values (e.g. '93.50') are kept as strings

    Args:
        file(str): path to a Log.final.out file produced by STAR.

    Returns:
        dict: STAR statistic name -> int or str value.
    """
    sumD = {}
    started = False
    for line in open(file):
        line = line.strip()
        if not started:
            # Header section: skip until the first empty line.
            if not line:
                started = True
            continue
        bar_pos = line.find('|')
        if bar_pos == -1:
            continue
        # Renamed from `type` to avoid shadowing the builtin.
        field = line[:bar_pos].strip()
        value = line[bar_pos+1:].strip()
        if '%' in value:
            value = value.replace("%", '')
        elif '.' not in value:
            value = int(value)
        sumD[field] = value
    if debug:
        # Works under both Python 2 and 3 (the original used `print >>`).
        sys.stderr.write(str(sumD) + '\n')
    return sumD
#---------------------------
def plot(fileL):
    """Render a bar plot for every file in `fileL` via the external `s-plot` tool."""
    for fname in fileL:
        os.system("s-plot barPlot -f " + fname)
#--------------------------------------
def plot_melt(total_melt, nameL):
    """Draw a stacked barplot of the melted summary table with `s-plot`.

    Args:
        total_melt(str): path of the melted (long-format) summary table.
        nameL(list): sample names, passed to s-plot as the x-axis level order.
    """
    quoted_names = ["'" + name + "'" for name in nameL]
    level_arg = '"' + ','.join(quoted_names) + '"'
    pieces = ["s-plot barPlot -m TRUE -a Sample -R 90 -B set -O 1 -w 20 -u 25 -f ",
              total_melt,
              ' -k free_y -L',
              level_arg,
              ' -y  \'Reads count or relative percent\' -x \'Samples\' ']
    os.system(' '.join(pieces))
#--------------------------------------
def plot_heatmap(totalTable_cnt, totalTable_per, sample_cnt):
    """Draw count and percent heatmaps of the summary tables with `s-plot`.

    Args:
        totalTable_cnt(str): path of the raw-count summary table.
        totalTable_per(str): path of the percentage summary table.
        sample_cnt(int): number of samples. Currently unused: the original
            derived an image `height` from it that was never applied (plot
            size is hard-coded via -u/-v); kept for interface compatibility.

    Returns:
        None. Invokes `s-plot` twice as a side effect.
    """
    # Shared s-plot options; only the input file and title differ per call.
    base = ["s-plot heatmapS -a TRUE -A 45 -b TRUE -R TRUE",
            "-x white -y blue -u 18 -v 30 -F 12 -T 0.5 -V 0.5"]
    os.system(' '.join(base + ["-f ", totalTable_cnt, "-I Count"]))
    os.system(' '.join(base + ["-f ", totalTable_per, "-I Percent"]))

#---------------------------------------

def generateDoc(report_dir, report_sub_dir, totalTable_cnt, 
        totalTable_per, total_melt, curation_label, melt, 
        bigwig, labelL):
    """Print the R-markdown (bookdown/knitr) report section for mapping QC.

    Copies the summary tables and figures into <report_dir>/<report_sub_dir>/
    (via `copy`/`copypdf` — helpers from the `tools` module) and prints the
    markdown text, including knitr code chunks, to stdout.

    Args:
        report_dir(str): top-level report directory.
        report_sub_dir(str): sub-directory under report_dir for this section.
        totalTable_cnt(str): path of the raw-count summary table.
        totalTable_per(str): path of the percentage summary table.
        total_melt(str): path of the melted table (used when `melt` is truthy).
        curation_label(str): incoming value is discarded (see NOTE below).
        melt: truthy when a stacked-barplot PDF was produced, falsy when
            heatmap PDFs were produced instead.
        bigwig(bool): when True, also copy and link per-sample bigwig files.
        labelL(list): sample labels (used only when `bigwig` is True).

    Returns:
        None. Output is printed; file copies happen as side effects.
    """
    dest_dir = report_dir+'/'+report_sub_dir+'/'
    os.system('mkdir -p '+dest_dir)
    # The figure produced upstream differs by mode: one stacked-bars PDF for
    # melt mode, two heatmap PDFs (count + percent) otherwise.
    if melt:
        pdf = total_melt+'.stackBars.pdf'
        copypdf(dest_dir, pdf)
    else:
        per_pdf = totalTable_per+'.heatmapS.pdf'
        cnt_pdf = totalTable_cnt+'.heatmapS.pdf'
        copypdf(dest_dir, per_pdf, cnt_pdf)
    copy(dest_dir, totalTable_cnt, totalTable_per)
    # Report-relative paths used inside the generated markdown.
    totalTable_cntNew = report_sub_dir+'/'+os.path.split(totalTable_cnt)[-1]
    totalTable_perNew = report_sub_dir+'/'+os.path.split(totalTable_per)[-1]


    print "\n## 序列比对质量总结 {#seq-map-summary-STAR}\n"
    # NOTE(review): the curation_label argument is overwritten here, so the
    # caller-supplied value is never used — confirm this is intentional.
    curation_label = "Reads_map_summary"
    knitr_read_txt(report_dir,  curation_label)
    
    print """
全部样品总的比对率如 Figure \@ref(fig:map-percentage-all-sample) and \@ref(fig:map-count-all-sample) 所示，用于从整体评估样品的比对率和可用reads数目。

```{{r map-percentage-all-sample, fig.cap="Distribution of reads mapping percentage for all samples. The vertical line represents median of mapping percentage."}}
map_percentage_all_sample <- read.table("{per}", header=T, sep="\t", quote="")
Final_kept_reads <- map_percentage_all_sample$Final_kept_reads
hist(Final_kept_reads, breaks=10, xlab="Reads mapping percentage (%)", ylab="Sample count", main="")
abline(v=median(Final_kept_reads), col="red")
```

```{{r map-count-all-sample, fig.cap="Distribution of reads mapping count for all samples. The vertical line represents median of mapping count."}}
map_count_all_sample <- read.table("{cnt}", header=T, sep="\t", quote="")
Final_kept_reads <- map_count_all_sample$Final_kept_reads/1e6
hist(Final_kept_reads, breaks=30, xlab="Reads mapping count (million)", ylab="Sample count", main="")
abline(v=median(Final_kept_reads), col="red")
```

""".format(per=totalTable_perNew, cnt=totalTable_cntNew)

    print """
每个样品测序序列比对总结见 (Table \@ref(tab:read-map-sum-cnt-table), Table \@ref(tab:read-map-sum-per-table) and Figure \@ref(fig:read-map-sum-fig))。

```{r map-sum-describe}
map_sum_describe = "Symbol;Explanation
Total;预处理后用于比对的中的reads数 
Unique_map;在基因组上有唯一比对位置的reads数
Multi_map_to_multiple_loci;在基因组上有多个比对位置（不多于20）的reads数, 一般保留
Multi_map_to_too_many_loci;在基因组上有很多比对位置的reads数
Unmap;未比对回基因组或转录组的reads数
Unmap_dueto_mismatch;因为错配太多未能比对回基因组或转录组的reads数
Unmap_dueto_too_short;因为比对长度太短（比对长度少于序列长度的2/3）而未能比对回基因组或转录组的reads数
Unmap_dueto_other;其它未比对回基因组或转录组的reads数
Final_kept_reads;过滤后用于下游分析的reads数"

map_sum_describe_data <- read.table(text=map_sum_describe, sep=";", header=T, quote="")
knitr::kable(map_sum_describe_data,format.args=list(big.mark=","))
#pander::pandoc.table(map_sum_describe_data, big.mark=',', justify='right')
```

"""

    print "Reads比对数据的统计可以帮助判断测序的质量、准确度、有无偏好和异常等。如果过滤后用于下游分析的reads数偏少，则需要慎重考虑。对于未比对回去的reads需要考虑未比对回去的原因，区别对待。\n"

    print "原始数据或者PDF格式的文件可以点击XLS或PDF下载。\n"
    
    print "(ref:read-map-sum-cnt-table) Summary raw counts of mapped reads. [XLS]({})\n".format(totalTable_cntNew)
    
    print '''```{{r read-map-sum-cnt-table, results="asis"}}
map_table <- read.table("{totalTable_cntNew}", sep="\\t", header=T, row.names=1, quote="", comment="")
knitr::kable(map_table, booktabs=T, caption="(ref:read-map-sum-cnt-table)", format="pandoc", format.args=list(big.mark=","))
```
'''.format(totalTable_cntNew=totalTable_cntNew)

    print "(ref:read-map-sum-per-table) Percent of mapped reads relative to total reads (%). [XLS]({})\n".format(totalTable_perNew)
    
    print '''```{{r read-map-sum-per-table}}
map_table <- read.table("{totalTable_perNew}", sep="\\t", header=T, row.names=1, quote="", comment="")
knitr::kable(map_table, booktabs=T, caption="(ref:read-map-sum-per-table)")
```
'''.format(totalTable_perNew=totalTable_perNew)
    
    # Figure section: single stacked-bars figure in melt mode, a side-by-side
    # pair of heatmap figures otherwise.
    if melt:
        pdf = report_sub_dir+'/'+os.path.split(pdf)[-1]
        print "(ref:read-map-sum-fig) Summary of reads mapping status. 1 Million = 10^6^. \
[PDF]({})\n".format(pdf)

        png = pdf.replace('pdf', 'png')
        print '''```{{r read-map-sum-fig, fig.cap="(ref:read-map-sum-fig)"}}
knitr::include_graphics("{png}")
```
'''.format(png=png)
    else:
        cnt_pdf = report_sub_dir+'/'+os.path.split(cnt_pdf)[-1]
        cnt_png = cnt_pdf.replace('pdf', 'png')
        per_pdf = report_sub_dir+'/'+os.path.split(per_pdf)[-1]
        per_png = per_pdf.replace('pdf', 'png')

        print "(ref:read-map-sum-fig) Summary of reads mapping status. 1 Million = 10^6^. \
[PDF_cnt]({}) [PDF_percent]({})\n".format(cnt_pdf, per_pdf)

        print '''```{{r read-map-sum-fig, out.width="49%", fig.cap="(ref:read-map-sum-fig)"}}
knitr::include_graphics(c("{cnt_png}", "{per_png}"))
```
'''.format(cnt_png=cnt_png, per_png=per_png)
    
    if bigwig:
        print "\n## 序列比对结果文件 {#seq-map-result-STAR}\n"

        copy(report_dir+"/images/", "/MPATHB/self/resource/sample/ucsc.png")
        
        print """
序列比对结果文件为[`bigwig`](http://genome.ucsc.edu/goldenPath/help/bigWig.html)格式文件可以导入[IGV](http://software.broadinstitute.org/software/igv/), [UCSC](http://genome.ucsc.edu/cgi-bin/hgTracks)等基因组浏览器，可视化测序reads在基因区和基因组区域的分布。


```{r igv-ucsc-example}
knitr::include_graphics("images/ucsc.png")
```
"""
        # Build a one-row-per-sample markdown table linking each bigwig file.
        bwL = ["Sample;Bigwig file"]
        for label in labelL:
            # Path layout assumed: <label>/<label>.Signal.UniqueMultiple.str1.out.bw
            bw = label+'/'+label+'.Signal.UniqueMultiple.str1.out.bw'
            copy(dest_dir, bw)
            bw = report_sub_dir+'/'+os.path.split(bw)[-1]
            tmp = label+';['+label+']('+bw+')'
            bwL.append(tmp)
        bwL = '\n'.join(bwL)

        print '''
```{{r, results="markdown"}}
igv_ucsc_bw = "{}"
igv_ucsc_bw_mat <- read.table(text=igv_ucsc_bw, sep=";", header=T)
knitr::kable(igv_ucsc_bw_mat, format="markdown")
```

'''.format(bwL)
#--------------------------------

def sum_float(*percent):
    """Sum percentage strings and format the total with two decimals.

    Args:
        *percent: strings such as '12.34%' or '12.34' (any '%' is stripped).

    Returns:
        str: the total formatted as '%.2f', e.g. '4.00'.

    Raises:
        ValueError: if a value is not numeric after removing '%'.
    """
    # `total` instead of the original `sum`, which shadowed the builtin.
    total = 0.0
    for item in percent:
        total += float(item.replace('%', ''))
    return "{:.2f}".format(total)

#-------------------------------------------
def multi(total, per):
    """Turn a percent string into an absolute count: int(total * percent / 100).

    Args:
        total(int): total number of reads.
        per(str): percentage such as '95.50%' (the '%' is optional).

    Returns:
        int: truncated absolute count.
    """
    value = float(per.replace('%', ''))
    # Keep the exact operation order of the original (total * value / 100)
    # so float rounding, and thus the truncated int, is identical.
    return int(total * value / 100)


def main():
    """Aggregate STAR Log.final.out files into count/percent tables, plot, and
    print the report section.

    Reads options from the command line (see cmdparameter), writes
    <prefix>.map_summary_cnt.xls, <prefix>.map_summary_percent.xls and, for
    small sample counts, <prefix>.melt_summary.xls, then plots and emits the
    markdown documentation via generateDoc.
    """
    options, args = cmdparameter(sys.argv)
    #-----------------------------------
    file = options.filein
    # NOTE(review): '[, ]*' can match the empty string; this only works
    # because Python 2's re.split skips zero-width matches. '[, ]+' would be
    # the safe pattern (Python 3.7+ splits on empty matches). Same at `label`.
    fileL = re.split(r'[, ]*', file.strip())
    sample_readin = len(fileL)
    label = options.label
    labelL = re.split(r'[, ]*', label.strip())
    bigwig = options.bigwig
    verbose = options.verbose
    op = options.out_prefix
    report_dir = options.report_dir
    report_sub_dir = options.report_sub_dir
    global debug
    debug = options.debug
    doc_only = options.doc_only
    num_samples_each_grp = options.number
    # melt=1: few samples -> stacked barplot; melt=0: many samples -> heatmaps.
    melt = 0
    if sample_readin <= num_samples_each_grp:
        melt = 1
    #-----------------------------------
    aDict = {}
    totalTable_cnt = op+".map_summary_cnt.xls"
    totalTable_per = op+".map_summary_percent.xls"
    total_melt = op+'.melt_summary.xls'
    curation_label = os.path.split(sys.argv[0])[-1].replace('.', '_')
    if doc_only:
        # Tables/figures already exist from a previous run; just re-emit doc.
        generateDoc(report_dir, report_sub_dir, totalTable_cnt, 
                totalTable_per, total_melt, curation_label, melt, bigwig, labelL)
        return 0

    totalTable_cnt_fh = open(totalTable_cnt, 'w')
    totalTable_per_fh = open(totalTable_per, 'w')
    print >>totalTable_cnt_fh, "Sample\tTotal\tUnique_map\tMulti_map_to_multiple_loci\tMulti_map_to_too_many_loci\tUnmap\tUnmap_dueto_mismatch\tUnmap_dueto_too_short\tUnmap_dueto_other\tFinal_kept_reads"
    print >>totalTable_per_fh, "Sample\tUnique_map\tMulti_map_to_multiple_loci\tMulti_map_to_too_many_loci\tUnmap\tUnmap_dueto_mismatch\tUnmap_dueto_too_short\tUnmap_dueto_other\tFinal_kept_reads"


    if melt:
        total_melt_fh = open(total_melt, 'w')
        print >>total_melt_fh, "Sample\tvariable\tvalue\tset"

    i = -1
    for file in fileL:
        i += 1
        #group = "Group_"+str(i // num_samples_each_grp+1)
        label = labelL[i]
        sumD = readStarFinalOut(file)
        # Raw counts come straight from STAR; unmapped counts are derived
        # from STAR's percentages below (STAR reports them only as %).
        total = sumD['Number of input reads']
        Unique_map = sumD['Uniquely mapped reads number']
        Unique_map_per = sumD['Uniquely mapped reads %']
        Multi_map_to_multiple_loci = sumD['Number of reads mapped to multiple loci']
        Multi_map_to_multiple_loci_per = sumD['% of reads mapped to multiple loci']
        Multi_map_to_too_many_loci = sumD['Number of reads mapped to too many loci']
        Multi_map_to_too_many_loci_per = sumD['% of reads mapped to too many loci']

        UnMap = total-Unique_map-Multi_map_to_multiple_loci-Multi_map_to_too_many_loci
        Mismtach_caused_unmap_per = sumD["% of reads unmapped: too many mismatches"]
        Short_caused_unmap_per    = sumD["% of reads unmapped: too short"]
        Other_caused_unmap_per    = sumD["% of reads unmapped: other"]
        # Convert each unmapped percentage back to an approximate read count.
        Mismtach_caused_unmap = multi(total, Mismtach_caused_unmap_per)
        Short_caused_unmap = multi(total, Short_caused_unmap_per)
        Other_caused_unmap = multi(total, Other_caused_unmap_per)
        
        UnMap_per = sum_float(Mismtach_caused_unmap_per, Short_caused_unmap_per, Other_caused_unmap_per)

        # Reads kept for downstream analysis: unique + multi (<=20 loci).
        # True division is in effect via `from __future__ import division`.
        Final_used = Unique_map+Multi_map_to_multiple_loci
        Final_used_per = "{:.2f}".format(Final_used/total*100)
        
        cntL = [label, total, Unique_map, 
            Multi_map_to_multiple_loci, Multi_map_to_too_many_loci, 
            UnMap, Mismtach_caused_unmap, Short_caused_unmap, 
            Other_caused_unmap, Final_used]

        percentL = [label, Unique_map_per, Multi_map_to_multiple_loci_per, 
            Multi_map_to_too_many_loci_per, UnMap_per, Mismtach_caused_unmap_per, 
            Short_caused_unmap_per, Other_caused_unmap_per, Final_used_per]
            
        cntL = [str(j) for j in cntL]
        percentL = [str(j) for j in percentL]

        print >>totalTable_cnt_fh, '\t'.join(cntL)
        print >>totalTable_per_fh, '\t'.join(percentL)
        
        if melt:
            # Long-format rows (Sample/variable/value/set) for the stacked
            # barplot; counts and percentages are distinguished by `set`.
            lineL = [
               [label, "Unique_map", str(Unique_map), "Raw reads count"], 
               [label, "Multi_map_to_multiple_loci", str(Multi_map_to_multiple_loci),
                   "Raw reads count"], 
               [label, "Multi_map_to_too_many_loci", str(Multi_map_to_too_many_loci), 
                   "Raw reads count"], 
               [label, "Unmap_dueto_mismatch", str(Mismtach_caused_unmap), 
                   "Raw reads count"], 
               [label, "Unmap_dueto_too_short", str(Short_caused_unmap), 
                   "Raw reads count"], 
               [label, "Unmap_dueto_other", str(Other_caused_unmap), 
                   "Raw reads count"],
               [label, "Unique_map", str(Unique_map_per), "Relative percent"], 
               [label, "Multi_map_to_multiple_loci", 
                   str(Multi_map_to_multiple_loci_per), "Relative percent"], 
               [label, "Multi_map_to_too_many_loci", 
                   str(Multi_map_to_too_many_loci_per), "Relative percent"], 
               [label, "Unmap_dueto_mismatch", str(Mismtach_caused_unmap_per), 
                   "Relative percent"], 
               [label, "Unmap_dueto_too_short", str(Short_caused_unmap_per), 
                   "Relative percent"], 
               [label, "Unmap_dueto_other", str(Other_caused_unmap_per), 
                   "Relative percent"]
            ]
            print >>total_melt_fh, "\n".join(['\t'.join(j) for j in lineL])
                


    totalTable_cnt_fh.close()
    totalTable_per_fh.close()


    if melt:
        total_melt_fh.close()
        plot_melt(total_melt, labelL)
    else:
        plot_heatmap(totalTable_cnt, totalTable_per, sample_readin)

    generateDoc(report_dir, report_sub_dir, totalTable_cnt, 
            totalTable_per, total_melt, curation_label, melt, bigwig, labelL)
    ###--------multi-process------------------

if __name__ == '__main__':
    startTime = strftime(timeformat, localtime())
    main()
    endTime = strftime(timeformat, localtime())
    # Append one run record (full command line + wall-clock span) to the
    # cumulative log. `with` guarantees the handle is closed even on error
    # (the original opened/closed manually and used the py2-only `print >>`).
    with open('python.log', 'a') as fh:
        fh.write("%s\n\tRun time : %s - %s \n" % \
            (' '.join(sys.argv), startTime, endTime))
    ###---------profile the program---------
    #import profile
    #profile_output = sys.argv[0]+".prof.txt")
    #profile.run("main()", profile_output)
    #import pstats
    #p = pstats.Stats(profile_output)
    #p.sort_stats("time").print_stats()
    ###---------profile the program---------

